import numpy as np
import torch
import pyro
from torch import Tensor

from scipy.interpolate import griddata, RectBivariateSpline
from scipy.special import logsumexp

from src.envs.base_environment import ContinuousEnvironment


def load_log_reward(num_grid_points, fes_file='./src/data/fes.dat', temperature=300):
    """
    Loads the free energy surface and returns the log of the reward density.

    The free energy F(phi, psi) read from ``fes_file`` is interpolated onto a regular
    ``num_grid_points`` x ``num_grid_points`` grid over [-pi, pi]^2. It is turned into the
    Boltzmann density exp(-beta * F) / Z and fitted with a bicubic spline in log space.

    Args:
        num_grid_points: number of grid points per angle used for interpolation and normalisation.
        fes_file: whitespace-separated text file whose first three columns are phi, psi (radians)
            and the free energy (kJ/mol). Relative paths resolve against the working directory.
        temperature: temperature in Kelvin used to generate the free energy data.

    Returns:
        A ``RectBivariateSpline`` ``s`` such that ``s(phi, psi, grid=False)`` is the log reward
        density. On the grid the density is normalised so that
        ``sum(density) * (2 * pi / num_grid_points) ** 2 == 1``.

    Raises:
        ValueError: if the interpolated free energy contains NaN (missing values in the file, or
            grid points the data does not cover).
    """
    # Molar gas constant R in kJ/(mol K): the free energies are per mole, so beta = 1 / (R T).
    gas_constant = 0.0083144621
    beta = 1 / (temperature * gas_constant)

    fe_data = np.genfromtxt(fes_file, autostrip=True)
    phi_traj = fe_data[:, 0]      # phi coordinates
    psi_traj = fe_data[:, 1]      # psi coordinates
    fe = fe_data[:, 2]            # free energy values
    fe = fe - np.min(fe)          # shift so that the minimum free energy is 0

    # Tile the data over all 3 x 3 neighbouring periods of the torus (every combination of
    # -2pi / 0 / +2pi shifts, including the axis-aligned ones). Cubic interpolation near the
    # +-pi edges then sees the correct periodic neighbours.
    shifts = (-2 * np.pi, 0.0, 2 * np.pi)
    offsets = [(d_phi, d_psi) for d_phi in shifts for d_psi in shifts]
    phi_extended = np.concatenate([phi_traj + d_phi for d_phi, _ in offsets])
    psi_extended = np.concatenate([psi_traj + d_psi for _, d_psi in offsets])
    fe_extended = np.tile(fe, len(offsets))

    phi_linespace = np.linspace(-np.pi, np.pi, num_grid_points)
    psi_linespace = np.linspace(-np.pi, np.pi, num_grid_points)
    # 'ij' indexing: FE[i, j] is the free energy at (phi_linespace[i], psi_linespace[j]), which
    # is the layout RectBivariateSpline(phi_linespace, psi_linespace, z) expects below.
    PHI, PSI = np.meshgrid(phi_linespace, psi_linespace, indexing='ij')

    # Interpolate the tiled data, evaluating only on the original [-pi, pi]^2 grid.
    FE = griddata((phi_extended, psi_extended), fe_extended, (PHI, PSI), method='cubic')
    if np.isnan(FE).any():
        raise ValueError(
            "Interpolated free energy contains NaN; check that the FES file covers "
            "[-pi, pi]^2 and has no missing values."
        )

    # Normalise in log space: log(exp(-beta F) / Z) = -beta F - logsumexp(-beta F).
    # Unlike exponentiating first, this cannot underflow to log(0) = -inf for high free energies.
    log_weights = -beta * FE
    log_reward_grid = log_weights - logsumexp(log_weights)
    # Turn grid probabilities into a density by assigning each grid point a cell of area
    # (2 * pi / num_grid_points) ** 2.
    # NOTE(review): linspace includes both -pi and pi (the same point on the torus) and its
    # spacing is 2 * pi / (num_grid_points - 1), so this cell area is an approximation; it only
    # shifts the log reward by a constant.
    log_reward_density = log_reward_grid + 2 * np.log(num_grid_points / (2 * np.pi))

    # Bicubic spline through the log density, queried as log_reward(phi, psi).
    return RectBivariateSpline(phi_linespace, psi_linespace, log_reward_density)


class AlanineDipeptideEnvironment(ContinuousEnvironment):
    """
    ### Description

    The Alanine Dipeptide environment is a 2D environment where the state holds the phi and psi backbone angles
    of the alanine dipeptide molecule. The action is a 2D vector of changes to the phi and psi angles.
    The goal is to sample the free energy surface of the alanine dipeptide molecule by applying actions to the angles.
    The reward density is defined as the Boltzmann weight of the free energy surface.

    Besides the two angles, `step` and `backward_step` also read and update a step counter stored in
    column 2 of the state tensor.

    ### Action Space

    | Num | Action    | Min | Max|
    |-----|-----------|-----|----|
    | 0   | Delta Phi | -pi | pi |
    | 1   | Delta Psi | -pi | pi |

    ### Observation Space

    | Num | Observation | Min | Max |
    |-----|-------------|-----|-----|
    | 0   | Phi         | -pi | pi  |
    | 1   | Psi         | -pi | pi  |

    ### Rewards

    The reward density is defined as the Boltzmann weight of the free energy surface.

    r = exp(-beta * F(x)) / Z

    where F(x) is the free energy surface, beta is the inverse temperature, and Z is the normalisation constant.

    ### Policy Parameterisation

    The policy is parameterised as a mixture model with `mixture_dim` components.
    Each component is a sine bivariate von Mises distribution over (phi, psi) with zero correlation.

    ### Arguments (read from `config["env"]`)

    - `max_log_concentration`: Maximum log concentration parameter for the von Mises distribution.
    - `min_log_concentration`: Minimum log concentration parameter for the von Mises distribution.
    - `num_grid_points`: Sequence whose first entry is the number of grid points per angle used to
      interpolate the free energy surface.
    - `mixture_dim`: Number of components in the von Mises mixture model in the parameterisation of the policy.

    `config["device"]` selects the device on which the state-space bounds are created.
    """

    def __init__(
            self,
            config):
        self._init_required_params(config)
        device = config["device"]
        lower_bound = torch.tensor([-np.pi, -np.pi], device=device)
        upper_bound = torch.tensor([np.pi, np.pi], device=device)
        # Only the first entry of `num_grid_points` is used: the free energy grid is square.
        self.log_reward_function = load_log_reward(config["env"]["num_grid_points"][0])
        mixture_dim = config["env"]["mixture_dim"]
        super().__init__(config,
                         dim=2,
                         feature_dim=2,
                         angle_dim=[True, True],
                         action_dim=2,
                         lower_bound=lower_bound,
                         upper_bound=upper_bound,
                         mixture_dim=mixture_dim,
                         # 5 outputs per component: mu_phi, mu_psi, conc_phi, conc_psi, weight
                         output_dim=5 * mixture_dim)

    def _init_required_params(self, config):
        """Reads the von Mises log-concentration bounds from ``config["env"]``.

        Raises:
            ValueError: if any required key is missing from ``config["env"]``.
        """
        required_params = ["max_log_concentration", "min_log_concentration"]
        # Explicit check (not `assert`) so validation also runs under `python -O`.
        missing = [param for param in required_params if param not in config["env"]]
        if missing:
            raise ValueError(f"Missing required parameters: {missing}")
        self.max_log_concentration = config["env"]["max_log_concentration"]
        self.min_log_concentration = config["env"]["min_log_concentration"]

    def log_reward(self, x):
        """Evaluates the interpolated log reward density at the angles in ``x``.

        Args:
            x: states whose ``[..., 0]`` and ``[..., 1]`` entries are phi and psi. Any device is
                accepted, and the tensor may require grad.

        Returns:
            A tensor on ``self.device`` with one log reward value per state.
        """
        # SciPy only understands host NumPy arrays: detach from autograd and move to CPU first,
        # otherwise CUDA tensors and grad-tracking tensors cannot be converted.
        coords = torch.as_tensor(x).detach().cpu().numpy()
        values = self.log_reward_function(x=coords[..., 0], y=coords[..., 1], grid=False)
        return torch.tensor(values, device=self.device)

    def step(self, x: Tensor, action: Tensor):
        """Takes a step in the environment given an action and returns the new state.

        Args:
            x: current states. Columns 0 and 1 hold phi and psi, column 2 holds a step counter
                (presumably shape [batch_size, 3]).
            action: [batch_size, 2] changes applied to (phi, psi).

        Returns:
            The new states: angles shifted by ``action`` and wrapped into [-pi, pi), step counter
            incremented by one.
        """
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] + action[:, 0]  # update phi coordinate
        new_x[:, 1] = x[:, 1] + action[:, 1]  # update psi coordinate
        new_x[:, 2] = x[:, 2] + 1             # increment step counter

        # Wrap both angles back into [-pi, pi); torch's `%` takes the sign of the divisor.
        new_x[:, :2] = (new_x[:, :2] + np.pi) % (2 * np.pi) - np.pi

        return new_x

    def backward_step(self, x: Tensor, action: Tensor):
        """Undoes ``action`` (the action that led to ``x``) and returns the previous state.

        Args:
            x: current states. Columns 0 and 1 hold phi and psi, column 2 holds a step counter
                (presumably shape [batch_size, 3]).
            action: [batch_size, 2] changes in (phi, psi) that had been applied to reach ``x``.

        Returns:
            The previous states: angles shifted back by ``action`` and wrapped into [-pi, pi),
            step counter decremented by one.
        """
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] - action[:, 0]  # update phi coordinate
        new_x[:, 1] = x[:, 1] - action[:, 1]  # update psi coordinate
        new_x[:, 2] = x[:, 2] - 1             # decrement step counter

        # Wrap both angles back into [-pi, pi); torch's `%` takes the sign of the divisor.
        new_x[:, :2] = (new_x[:, :2] + np.pi) % (2 * np.pi) - np.pi

        return new_x

    def compute_initial_action(self, first_state):
        """Returns the action leading from ``self.init_value`` to ``first_state``.

        ``init_value`` is not defined in this class (presumably set by the base class).
        """
        # NOTE(review): unlike `step`, this difference is not wrapped into [-pi, pi); confirm intended.
        return first_state - self.init_value

    def _init_policy_dist(self, param_dict):
        """Initialises a mixture of von Mises distributions. Used for policy parameterisation.

        Args:
            param_dict: dict with keys "mus_phi", "mus_psi", "concs_phi", "concs_psi" and "weights",
                as produced by `postprocess_params`.

        Returns:
            A single SineBivariateVonMises when ``mixture_dim == 1``, otherwise a
            MixtureSameFamily over ``mixture_dim`` SineBivariateVonMises components.

        Raises:
            ValueError: if ``mixture_dim`` is smaller than 1.
        """
        if self.mixture_dim < 1:
            raise ValueError(f"mixture_dim must be at least 1, got {self.mixture_dim}")

        mus_phi = param_dict["mus_phi"]
        mus_psi = param_dict["mus_psi"]
        concs_phi = param_dict["concs_phi"]
        concs_psi = param_dict["concs_psi"]

        if self.mixture_dim == 1:
            # Drop the trailing singleton component axis: a plain distribution, no mixture.
            mus_phi = mus_phi.squeeze(-1)
            mus_psi = mus_psi.squeeze(-1)
            concs_phi = concs_phi.squeeze(-1)
            concs_psi = concs_psi.squeeze(-1)

        # We set the correlation parameter of the bivariate von Mises distribution to 0 to model
        # protein backbone angles, see e.g. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3118414/
        sbvm = pyro.distributions.SineBivariateVonMises(mus_phi, mus_psi, concs_phi, concs_psi, correlation=0)
        # Raise the iteration cap of pyro's sampler (presumably to avoid sampling failures).
        sbvm.max_sample_iter = 25000

        if self.mixture_dim == 1:
            return sbvm

        mix = torch.distributions.Categorical(param_dict["weights"])
        comp = torch.distributions.Independent(sbvm, 0)
        return torch.distributions.MixtureSameFamily(mix, comp)

    def postprocess_params(self, params):
        """Maps raw network outputs to the parameters of the policy distribution.

        Args:
            params: 2D tensor whose columns are laid out as
                [mu_phi | mu_psi | conc_phi | conc_psi | weight], each group ``mixture_dim`` wide
                (the weight group takes all remaining columns).

        Returns:
            dict with keys "mus_phi", "mus_psi", "concs_phi", "concs_psi" and "weights".
        """
        k = self.mixture_dim
        mu_phi_params = params[:, :k]
        mu_psi_params = params[:, k:2 * k]
        conc_phi_params = params[:, 2 * k:3 * k]
        conc_psi_params = params[:, 3 * k:4 * k]
        weight_params = params[:, 4 * k:]

        # 2 * pi * atan(.) lies in (-pi**2, pi**2), NOT (-pi, pi). The von Mises density is
        # 2*pi-periodic in its location, so every angle is still reachable.
        # NOTE(review): if (-pi, pi) was intended, use 2 * torch.atan(.) instead; that changes the
        # mapping used by existing checkpoints.
        mus_phi = np.pi * torch.atan(mu_phi_params) * 2
        mus_psi = np.pi * torch.atan(mu_psi_params) * 2

        # Squash into the configured log-concentration range, then exponentiate, so the
        # concentrations lie in (exp(min_log_concentration), exp(max_log_concentration)).
        log_conc_range = self.max_log_concentration - self.min_log_concentration
        concs_phi = torch.exp(torch.sigmoid(conc_phi_params) * log_conc_range + self.min_log_concentration)
        concs_psi = torch.exp(torch.sigmoid(conc_psi_params) * log_conc_range + self.min_log_concentration)

        # Mixture weights: softmax over the component axis.
        weights = torch.softmax(weight_params, dim=1)

        return {"mus_phi": mus_phi, "mus_psi": mus_psi, "concs_phi": concs_phi, "concs_psi": concs_psi, "weights": weights}

    def add_noise(self, param_dict: dict, off_policy_noise: float):
        """Widens the policy distribution for noisy exploration.

        Each concentration kappa is converted to sigma = sqrt(1 / kappa), ``off_policy_noise`` is
        added to sigma, and the result is converted back via kappa = 1 / sigma**2.
        ``param_dict`` is modified in place and also returned.
        """
        sigma_phi = torch.sqrt(1 / param_dict["concs_phi"])
        sigma_psi = torch.sqrt(1 / param_dict["concs_psi"])
        param_dict["concs_phi"] = 1 / (sigma_phi + off_policy_noise) ** 2
        param_dict["concs_psi"] = 1 / (sigma_psi + off_policy_noise) ** 2

        return param_dict
